{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM 详解代码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data - IMDB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 小提示： 每次重新运行时请尽量安装模块进行\n",
    "# 比如完整的重新运行 MyLSTM以下所有内容\n",
    "# 单独重新运行一个 cell，可能出现意料之外的情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 借助 torchtext 加载 IMDB 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset import\n",
    "# !pip install torchtext torchdata\n",
    "from torchtext.datasets import IMDB\n",
    "from torchtext.datasets.imdb import NUM_LINES\n",
    "from torchtext.data import get_tokenizer\n",
    "from torchtext.vocab import build_vocab_from_iterator\n",
    "from torchtext.data.functional import to_map_style_dataset\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import utils\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# log 以及工具\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "import sys\n",
    "import logging\n",
    "logging.basicConfig(\n",
    "    level=logging.WARN, stream=sys.stdout, \\\n",
    "    format=\"%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s\")\n",
    "\n",
    "# 设备 无显卡会被设置为 cpu\n",
    "device = 'cuda'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed=42):\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "\n",
    "seed_everything(1998)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_iter = IMDB(root=\"./data\", split=\"train\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 对输入文本进行分割，返回值是 一个数组\n",
    "def yeild_tokens(train_data_iter, tokenizer):\n",
    "    for i, sample in enumerate(train_data_iter):\n",
    "        label, comment = sample\n",
    "        # 打开 cell 4 中的注释时请注意切换此处注释\n",
    "        yield tokenizer(comment) \n",
    "        # return tokenizer(comment)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 请打开下方注释理解 yeild_tokens \n",
    "# x = yeild_tokens(train_data_iter, tokenizer)\n",
    "# x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "单词表大小: 13351\n"
     ]
    }
   ],
   "source": [
    "# 分词、构建词表 这里第一次运行可能需要较长时间\n",
    "tokenizer = get_tokenizer(\"basic_english\")\n",
    "# 只使用出现次数大于20的token\n",
    "vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=[\"<unk>\"])\n",
    "vocab.set_default_index(0)  # 特殊索引设置为0\n",
    "print(f'单词表大小: {len(vocab)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建词向量 -- 可以使用 word2vec 或者 glovec 等方式来替代随机生成\n",
    "embedding = nn.Embedding(len(vocab), 64)\n",
    "# a = torch.LongTensor([0])\n",
    "# a\n",
    "# embedding(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 由于 LSTM 这里考虑使用 batch 进行处理\n",
    "# 针对同一 batch 中 句子长度不同的情况\n",
    "# 每次都对长度不足的句子进行 padding\n",
    "# 另外将标记由 1、2 修改为 0、1\n",
    "def collate_fn(batch):\n",
    "    \"\"\"\n",
    "    对DataLoader所生成的mini-batch进行后处理\n",
    "    \"\"\"\n",
    "    # print(batch)\n",
    "    target = []\n",
    "    token_index = []\n",
    "    max_length = 0  # 最大的token长度\n",
    "    for i, (label, comment) in enumerate(batch):\n",
    "        tokens = tokenizer(comment)\n",
    "        # print(tokens)\n",
    "        # print(vocab(tokens))\n",
    "        token_index.append(vocab(tokens)) # 字符列表转换为索引列表\n",
    "        \n",
    "        # 确定最大的句子长度\n",
    "        if len(tokens) > max_length:\n",
    "            max_length = len(tokens)\n",
    "        # 设定目标 label 1标记为 0  2标记为 1\n",
    "        if label == 1:\n",
    "            target.append(0)\n",
    "        else:\n",
    "            target.append(1)\n",
    "    # print(token_index)\n",
    "    \n",
    "    # padding 到最长长度\n",
    "    token_index = [index + [0]*(max_length-len(index)) for index in token_index]\n",
    "    # 词向量化\n",
    "    token_index = embedding(torch.tensor(token_index).to(torch.int32))\n",
    "    # print(token_index.shape)\n",
    "    # one-hot接收长整形的数据，所以要转换为int64\n",
    "    return (torch.tensor(target).to(torch.int64),token_index )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 理解上方代码请打开下方注释，其中 batch_size >= 2 时 请注意最大长度填充。\n",
    "\n",
    "# for batch_index, (target, token_index) in enumerate(train_data_loader):\n",
    "#     print(batch_index)\n",
    "#     print(target)\n",
    "#     print(token_index)\n",
    "#     # (batch_size, seq_len, input_size)\n",
    "#     print(token_index.shape)\n",
    "#     break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义超参数\n",
    "input_size = 64 #\n",
    "hidden_size = 128 \n",
    "num_layers = 1\n",
    "num_classes = 2\n",
    "batch_size = 32\n",
    "max_seq_len = 512\n",
    "learning_rate = 0.01"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train Dataloader\n",
    "train_data_iter = IMDB(root=\"data\", split=\"train\")\n",
    "train_data_loader = torch.utils.data.DataLoader(\n",
    "    to_map_style_dataset(train_data_iter), batch_size=batch_size, collate_fn=collate_fn, shuffle=True)\n",
    "\n",
    "# Eval Dataloader\n",
    "eval_data_iter = IMDB(root=\"data\", split=\"test\")\n",
    "eval_data_loader = utils.data.DataLoader(\n",
    "    to_map_style_dataset(eval_data_iter), batch_size=batch_size, collate_fn=collate_fn)\n",
    "\n",
    "#to_map_style_dataset 可以自行百度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  Gate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "hideen_size_temp = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sigmoid()"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sigmoid = nn.Sigmoid()\n",
    "sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1.],\n",
       "        [1., 1., 1.],\n",
       "        [1., 1., 1.]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hid = torch.ones(hideen_size_temp, hideen_size_temp)\n",
    "hid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=3, out_features=3, bias=True)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model =  nn.Linear(hideen_size_temp, hideen_size_temp)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2625,  0.3412, -0.5055],\n",
       "        [-0.2625,  0.3412, -0.5055]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mid_output = model(x)\n",
    "mid_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4347, 0.5845, 0.3762],\n",
       "        [0.4347, 0.5845, 0.3762]], grad_fn=<SigmoidBackward0>)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gate = sigmoid(mid_output)\n",
    "gate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1141,  0.1994, -0.1902],\n",
       "        [-0.1141,  0.1994, -0.1902]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_output = gate * mid_output\n",
    "final_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 再次定义超参数 避免干扰下方训练\n",
    "input_size = 64 #\n",
    "hidden_size = 128 \n",
    "num_layers = 1\n",
    "num_classes = 2\n",
    "batch_size = 32\n",
    "max_seq_len = 512\n",
    "learning_rate = 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MyLSTM"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "基本无需修改，可以看到效果尚可，相比官方实现略有欠缺，但是能证明是一个有效的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义基础模型\n",
    "class LSTM(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_layers, num_classes):\n",
    "        \"\"\"\n",
    "        args:\n",
    "            input_size: 输入大小\n",
    "            hidden_size: 隐藏层大小\n",
    "            num_layers: 几层的LSTM\n",
    "            num_classes: 最后输出的类别，在这个示例中，输出应该是 0 或者 1\n",
    "            \n",
    "        \"\"\"\n",
    "        super(LSTM, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        self.fc_i = nn.Linear(input_size + hidden_size, hidden_size)\n",
    "        self.fc_f = nn.Linear(input_size + hidden_size, hidden_size)\n",
    "        self.fc_g = nn.Linear(input_size + hidden_size, hidden_size)\n",
    "        self.fc_o = nn.Linear(input_size + hidden_size, hidden_size)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "        self.tanh = nn.Tanh()\n",
    "        self.fc_out = nn.Linear(hidden_size, num_classes)\n",
    "    def forward(self, x):\n",
    "        # shape (batch_size, seq_len, input_size)\n",
    "        # print(x.shape)\n",
    "        h_t = torch.zeros(x.size(0), x.size(1), self.hidden_size).to(x.device)\n",
    "        c_t = torch.zeros(x.size(0), x.size(1), self.hidden_size).to(x.device)\n",
    "        # print(h_t.shape)\n",
    "        # print(c_t.shape)\n",
    "        combined = torch.cat((x, h_t), dim=2)\n",
    "        i_t = self.sigmoid(self.fc_i(combined))\n",
    "        f_t = self.sigmoid(self.fc_f(combined))\n",
    "        g_t = self.tanh(self.fc_g(combined))\n",
    "        o_t = self.sigmoid(self.fc_o(combined))\n",
    "        c_t = f_t * c_t + i_t * g_t\n",
    "        h_t = o_t * self.tanh(c_t)\n",
    "            \n",
    "#         print(x.shape)\n",
    "#         print(combined.shape)\n",
    "#         print(i_t.shape)\n",
    "#         print(f_t.shape)\n",
    "#         print(g_t.shape)\n",
    "#         print(o_t.shape)\n",
    "#         print(h_t.shape)\n",
    "        h_t = F.avg_pool2d(h_t, (h_t.shape[1],1)).squeeze()\n",
    "        out = self.fc_out(h_t)\n",
    "#         print(out.cpu().shape)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型参数数量： 99074\n"
     ]
    }
   ],
   "source": [
    "# 检查模型是否存在问题\n",
    "# 设置随机数种子以保证结果可重复\n",
    "torch.manual_seed(2023)\n",
    "\n",
    "# 生成测试数据\n",
    "x = torch.randn(batch_size, max_seq_len, input_size).to(device)\n",
    "y = torch.randint(0, num_classes, (batch_size,)).to(device)\n",
    "\n",
    "# 初始化模型\n",
    "model = LSTM(input_size, hidden_size, num_layers, num_classes)\n",
    "model.to(device)\n",
    "# 打印模型参数数量\n",
    "print(\"模型参数数量：\", sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "\n",
    "# 计算模型输出\n",
    "output = model(x)\n",
    "# output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LSTM(\n",
       "  (fc_i): Linear(in_features=192, out_features=128, bias=True)\n",
       "  (fc_f): Linear(in_features=192, out_features=128, bias=True)\n",
       "  (fc_g): Linear(in_features=192, out_features=128, bias=True)\n",
       "  (fc_o): Linear(in_features=192, out_features=128, bias=True)\n",
       "  (sigmoid): Sigmoid()\n",
       "  (tanh): Tanh()\n",
       "  (fc_out): Linear(in_features=128, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看模型结构\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:03,079 (233408485:20) WARNING: epoch_index: 0, batch_index: 0, loss: 0.6886175274848938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "298it [00:07, 38.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:10,820 (233408485:20) WARNING: epoch_index: 0, batch_index: 300, loss: 0.6428654193878174\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "599it [00:15, 39.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:18,390 (233408485:20) WARNING: epoch_index: 0, batch_index: 600, loss: 0.48170384764671326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:20, 39.09it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:23,065 (233408485:20) WARNING: epoch_index: 1, batch_index: 0, loss: 0.44872817397117615\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "300it [00:07, 37.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:30,612 (233408485:20) WARNING: epoch_index: 1, batch_index: 300, loss: 0.5459613800048828\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "600it [00:15, 34.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:38,373 (233408485:20) WARNING: epoch_index: 1, batch_index: 600, loss: 0.41060134768486023\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:20, 39.04it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:43,096 (233408485:20) WARNING: epoch_index: 2, batch_index: 0, loss: 0.25612199306488037\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "297it [00:07, 38.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:51,098 (233408485:20) WARNING: epoch_index: 2, batch_index: 300, loss: 0.33394932746887207\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "596it [00:15, 41.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:37:58,503 (233408485:20) WARNING: epoch_index: 2, batch_index: 600, loss: 0.4152831733226776\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:19, 39.31it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:02,998 (233408485:20) WARNING: epoch_index: 3, batch_index: 0, loss: 0.31165584921836853\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "296it [00:07, 41.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:10,474 (233408485:20) WARNING: epoch_index: 3, batch_index: 300, loss: 0.348602294921875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "600it [00:15, 35.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:18,298 (233408485:20) WARNING: epoch_index: 3, batch_index: 600, loss: 0.22047576308250427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:20, 38.53it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:23,300 (233408485:20) WARNING: epoch_index: 4, batch_index: 0, loss: 0.3187304735183716\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "300it [00:07, 39.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:31,194 (233408485:20) WARNING: epoch_index: 4, batch_index: 300, loss: 0.17088855803012848\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "600it [00:15, 36.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:39,150 (233408485:20) WARNING: epoch_index: 4, batch_index: 600, loss: 0.4386771321296692\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:20, 38.19it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:43,787 (233408485:20) WARNING: epoch_index: 5, batch_index: 0, loss: 0.2063377946615219\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "300it [00:07, 39.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:51,578 (233408485:20) WARNING: epoch_index: 5, batch_index: 300, loss: 0.3431491255760193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "600it [00:15, 41.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:38:59,081 (233408485:20) WARNING: epoch_index: 5, batch_index: 600, loss: 0.2605396807193756\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:19, 39.37it/s]\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "cur_epoch = -1\n",
    "for epoch_index in range(0,6):\n",
    "    num_batches = len(train_data_loader)\n",
    "    for batch_index, (target, token_index) in tqdm(enumerate(train_data_loader)):\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "        target = target.to(device)\n",
    "        token_index = token_index.to(device)\n",
    "        step = num_batches*(epoch_index) + batch_index + 1   \n",
    "        logits = model(token_index)\n",
    "        loss = F.nll_loss(F.log_softmax(logits,dim=-1), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_index % 300 == 0:\n",
    "            cur_epoch = epoch_index\n",
    "            logging.warning(f\"epoch_index: {epoch_index}, batch_index: {batch_index}, loss: {loss}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:10, 71.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:14,595 (612167170:11) WARNING: eval_loss: 0.6957253813743591, eval_acc: 0.86988\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "total_acc_account = 0\n",
    "total_account = 0\n",
    "for eval_batch_index, (eval_target, eval_token_index) in tqdm(enumerate(eval_data_loader)):\n",
    "    eval_target = eval_target.to(device)\n",
    "    eval_token_index = eval_token_index.to(device)\n",
    "    total_account += eval_target.shape[0]\n",
    "    eval_logits = model(eval_token_index)\n",
    "    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()\n",
    "    eval_loss = F.nll_loss(F.log_softmax(eval_logits,dim=-1), eval_target)\n",
    "logging.warning(f\"eval_loss: {eval_loss}, eval_acc: {total_acc_account / total_account}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LSTM-Pytorch"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "可以修改训练时候的 epoch 数字，在当前代码下，3 训练不足， 6容易过拟合，但是可以辅助证明我们的实现没有问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_layers, num_classes):\n",
    "        \"\"\"\n",
    "        args:\n",
    "            input_size: 输入大小\n",
    "            hidden_size: 隐藏层大小\n",
    "            num_layers: 几层的LSTM\n",
    "            num_classes: 最后输出的类别，在这个示例中，输出应该是 0 或者 1\n",
    "            \n",
    "        \"\"\"\n",
    "        super(LSTM, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        self.lstm = nn.LSTM(self.input_size, self.hidden_size, num_layers=num_layers,batch_first=True,bidirectional=False)\n",
    "        self.fc_out = nn.Linear(hidden_size, num_classes)\n",
    "    def forward(self, x):\n",
    "        x,_ = self.lstm(x)\n",
    "        x = F.avg_pool2d(x, (x.shape[1],1)).squeeze()\n",
    "        out = self.fc_out(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型参数数量： 99586\n"
     ]
    }
   ],
   "source": [
    "# 检查模型是否存在问题\n",
    "# 设置随机数种子以保证结果可重复\n",
    "\n",
    "\n",
    "\n",
    "# 生成测试数据\n",
    "x = torch.randn(batch_size, max_seq_len, input_size).to(device)\n",
    "y = torch.randint(0, num_classes, (batch_size,)).to(device)\n",
    "\n",
    "# 初始化模型\n",
    "model = LSTM(input_size, hidden_size, num_layers, num_classes)\n",
    "model.to(device)\n",
    "# 打印模型参数数量\n",
    "print(\"模型参数数量：\", sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "\n",
    "# 计算模型输出\n",
    "output = model(x)\n",
    "# output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LSTM(\n",
       "  (lstm): LSTM(64, 128, batch_first=True)\n",
       "  (fc_out): Linear(in_features=128, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看模型结构\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:14,813 (3862857436:23) WARNING: epoch_index: 0, batch_index: 0, loss: 0.6966263055801392\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "299it [00:08, 38.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:23,082 (3862857436:23) WARNING: epoch_index: 0, batch_index: 300, loss: 0.5508784055709839\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "599it [00:16, 36.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:31,275 (3862857436:23) WARNING: epoch_index: 0, batch_index: 600, loss: 0.5803604125976562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:21, 36.57it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:36,188 (3862857436:23) WARNING: epoch_index: 1, batch_index: 0, loss: 0.40404248237609863\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "298it [00:08, 34.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:44,230 (3862857436:23) WARNING: epoch_index: 1, batch_index: 300, loss: 0.28521886467933655\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "600it [00:16, 37.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:52,248 (3862857436:23) WARNING: epoch_index: 1, batch_index: 600, loss: 0.2420959770679474\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:20, 37.49it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:39:57,044 (3862857436:23) WARNING: epoch_index: 2, batch_index: 0, loss: 0.40587106347084045\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "296it [00:08, 35.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:40:05,171 (3862857436:23) WARNING: epoch_index: 2, batch_index: 300, loss: 0.4198710024356842\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "599it [00:16, 37.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:40:13,358 (3862857436:23) WARNING: epoch_index: 2, batch_index: 600, loss: 0.33481597900390625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:21, 36.72it/s]\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "cur_epoch = -1\n",
    "for epoch_index in range(0,3):\n",
    "    num_batches = len(train_data_loader)\n",
    "    for batch_index, (target, token_index) in tqdm(enumerate(train_data_loader)):\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "        target = target.to(device)\n",
    "        token_index = token_index.to(device)\n",
    "        step = num_batches*(epoch_index) + batch_index + 1   \n",
    "        logits = model(token_index)】\n",
    "        loss = F.nll_loss(F.log_softmax(logits,dim=-1), target)\n",
    "        loss.backward()\n",
    "        # nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度的正则进行截断，保证训练稳定\n",
    "        optimizer.step()\n",
    "        if batch_index % 300 == 0:\n",
    "            cur_epoch = epoch_index\n",
    "            logging.warning(f\"epoch_index: {epoch_index}, batch_index: {batch_index}, loss: {loss}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:10, 74.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:40:28,881 (612167170:11) WARNING: eval_loss: 0.6789069175720215, eval_acc: 0.8048\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "total_acc_account = 0\n",
    "total_account = 0\n",
    "for eval_batch_index, (eval_target, eval_token_index) in tqdm(enumerate(eval_data_loader)):\n",
    "    eval_target = eval_target.to(device)\n",
    "    eval_token_index = eval_token_index.to(device)\n",
    "    total_account += eval_target.shape[0]\n",
    "    eval_logits = model(eval_token_index)\n",
    "    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()\n",
    "    eval_loss = F.nll_loss(F.log_softmax(eval_logits,dim=-1), eval_target)\n",
    "logging.warning(f\"eval_loss: {eval_loss}, eval_acc: {total_acc_account / total_account}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RNN"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "IMDB 数据样本长度大概在 1000 左右，加之词向量是随机生成的，所以 RNN 模型不够稳定\n",
    "容易梯度爆炸和过拟合，，，可以考虑调低学习率或者重新运行\n",
    "可以手动调整一下学习率和训练 epoch，如果 loss 稳定在 0.69 或者出现 nan 则表示存在问题\n",
    "能低于 0.6 说明可以继续跑"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class RNN(nn.Module):\n",
    "    \"\"\"Sequence classifier: nn.RNN encoder, mean pooling over time, linear head.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features per time step of the input.\n",
    "        output_size: number of output logits (classes).\n",
    "        hidden_dim: size of the recurrent hidden state.\n",
    "        n_layers: number of stacked recurrent layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, output_size, hidden_dim, n_layers):\n",
    "        super(RNN, self).__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.n_layers = n_layers\n",
    "        # batch_first=True: inputs/outputs are laid out as (batch, seq_len, features).\n",
    "        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_dim, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return logits of shape (batch, output_size).\n",
    "\n",
    "        Assumes ``x`` is a floating-point tensor of shape (batch, seq_len, input_size).\n",
    "        \"\"\"\n",
    "        hidden = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)\n",
    "        # out holds the top layer's hidden state at every time step: (batch, seq_len, hidden_dim).\n",
    "        out, hidden = self.rnn(x, hidden)\n",
    "        # Average over the time axis. mean(dim=1) keeps the batch dimension even when\n",
    "        # batch_size == 1 (a bare .squeeze() would drop it and break the loss).\n",
    "        out = out.mean(dim=1)\n",
    "        return self.fc(out)\n",
    "\n",
    "    def init_hidden(self, batch_size, device=None, dtype=None):\n",
    "        \"\"\"Return an all-zero initial hidden state of shape (n_layers, batch_size, hidden_dim).\n",
    "\n",
    "        ``device``/``dtype`` default to those of the model's parameters, so the state is\n",
    "        created where the model lives rather than on the notebook-global ``device``.\n",
    "        \"\"\"\n",
    "        reference = next(self.parameters())\n",
    "        return torch.zeros(\n",
    "            self.n_layers, batch_size, self.hidden_dim,\n",
    "            device=reference.device if device is None else device,\n",
    "            dtype=reference.dtype if dtype is None else dtype,\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型参数数量： 25090\n"
     ]
    }
   ],
   "source": [
    "# Smoke test: run the freshly built model on random inputs.\n",
    "# batch_size, max_seq_len, input_size, hidden_size and num_classes are defined in earlier cells.\n",
    "x = torch.randn(batch_size, max_seq_len, input_size).to(device)\n",
    "y = torch.randint(0, num_classes, (batch_size,)).to(device)\n",
    "# Build the model with a single recurrent layer (it creates its own zero hidden state).\n",
    "model = RNN(input_size, num_classes, hidden_size, 1)\n",
    "model.to(device)\n",
    "\n",
    "# Number of trainable parameters\n",
    "print(\"模型参数数量：\", sum(p.numel() for p in model.parameters() if p.requires_grad))\n",
    "\n",
    "# Forward pass: the classifier must return one logit vector per sample.\n",
    "output = model(x)\n",
    "assert output.shape == (batch_size, num_classes), output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RNN(\n",
       "  (rnn): RNN(64, 128, batch_first=True)\n",
       "  (fc): Linear(in_features=128, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module structure (layer types and sizes)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:03,908 (1096697844:23) WARNING: epoch_index: 12, batch_index: 0, loss: 0.6250293254852295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "300it [00:06, 50.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:10,319 (1096697844:23) WARNING: epoch_index: 12, batch_index: 300, loss: 0.6724219918251038\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "599it [00:12, 42.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:16,588 (1096697844:23) WARNING: epoch_index: 12, batch_index: 600, loss: 0.5623435378074646\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:16, 47.56it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:20,355 (1096697844:23) WARNING: epoch_index: 13, batch_index: 0, loss: 0.5873069167137146\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "295it [00:06, 49.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:26,896 (1096697844:23) WARNING: epoch_index: 13, batch_index: 300, loss: 0.5667327642440796\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "597it [00:12, 46.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:33,395 (1096697844:23) WARNING: epoch_index: 13, batch_index: 600, loss: 0.630550742149353\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:16, 47.39it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:36,858 (1096697844:23) WARNING: epoch_index: 14, batch_index: 0, loss: 0.5228468179702759\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "300it [00:06, 43.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:42,946 (1096697844:23) WARNING: epoch_index: 14, batch_index: 300, loss: 0.46919119358062744\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "598it [00:12, 44.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:49,330 (1096697844:23) WARNING: epoch_index: 14, batch_index: 600, loss: 0.6358293294906616\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:16, 47.77it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:53,238 (1096697844:23) WARNING: epoch_index: 15, batch_index: 0, loss: 0.6251819729804993\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "297it [00:06, 52.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 00:59:59,394 (1096697844:23) WARNING: epoch_index: 15, batch_index: 300, loss: 0.5249137282371521\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "597it [00:11, 49.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:05,039 (1096697844:23) WARNING: epoch_index: 15, batch_index: 600, loss: 0.6910139918327332\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:15, 51.67it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:08,375 (1096697844:23) WARNING: epoch_index: 16, batch_index: 0, loss: 0.5868373513221741\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "299it [00:05, 52.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:14,215 (1096697844:23) WARNING: epoch_index: 16, batch_index: 300, loss: 0.6238973736763\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "595it [00:11, 46.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:20,129 (1096697844:23) WARNING: epoch_index: 16, batch_index: 600, loss: 0.48795947432518005\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:15, 49.65it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:24,120 (1096697844:23) WARNING: epoch_index: 17, batch_index: 0, loss: 0.5158190131187439\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "296it [00:06, 48.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:30,500 (1096697844:23) WARNING: epoch_index: 17, batch_index: 300, loss: 0.7333455681800842\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "596it [00:11, 56.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:00:36,137 (1096697844:23) WARNING: epoch_index: 17, batch_index: 600, loss: 0.5788255333900452\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:15, 50.98it/s]\n"
     ]
    }
   ],
   "source": [
    "# Notes from earlier runs: lr=0.01 diverged; lr=0.001 was used for up to 18 epochs.\n",
    "# This cell resumes training at START_EPOCH with a lower lr. Re-running it re-creates\n",
    "# the optimizer, so Adam's moment estimates start from scratch each time.\n",
    "START_EPOCH, END_EPOCH = 12, 18\n",
    "LEARNING_RATE = 1e-4\n",
    "MAX_GRAD_NORM = 0.1  # clip the gradient norm to keep RNN training stable\n",
    "LOG_EVERY = 300  # log the loss every LOG_EVERY batches\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "num_batches = len(train_data_loader)\n",
    "model.train()  # nothing in this loop switches to eval mode, so set it once\n",
    "cur_epoch = -1\n",
    "for epoch_index in range(START_EPOCH, END_EPOCH):\n",
    "    cur_epoch = epoch_index\n",
    "    for batch_index, (target, token_index) in tqdm(enumerate(train_data_loader), total=num_batches):\n",
    "        target = target.to(device)\n",
    "        token_index = token_index.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(token_index)\n",
    "        loss = F.nll_loss(F.log_softmax(output, dim=-1), target)\n",
    "        loss.backward()\n",
    "        nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
    "        optimizer.step()\n",
    "        if batch_index % LOG_EVERY == 0:\n",
    "            logging.warning(f\"epoch_index: {epoch_index}, batch_index: {batch_index}, loss: {loss}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "782it [00:08, 92.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-03-17 01:01:44,283 (612167170:11) WARNING: eval_loss: 1.3197190761566162, eval_acc: 0.67968\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the held-out set: accuracy and mean per-sample NLL loss.\n",
    "model.eval()\n",
    "total_acc_account = 0  # number of correctly classified samples\n",
    "total_account = 0  # number of evaluated samples\n",
    "total_loss_sum = 0.0  # summed (not averaged) per-sample loss\n",
    "# no_grad: evaluation needs no autograd graph, which saves memory and time.\n",
    "with torch.no_grad():\n",
    "    for eval_target, eval_token_index in tqdm(eval_data_loader, total=len(eval_data_loader)):\n",
    "        eval_target = eval_target.to(device)\n",
    "        eval_token_index = eval_token_index.to(device)\n",
    "        total_account += eval_target.shape[0]\n",
    "        eval_logits = model(eval_token_index)\n",
    "        total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()\n",
    "        # reduction='sum' so batches of different sizes are weighted correctly in the mean.\n",
    "        total_loss_sum += F.nll_loss(F.log_softmax(eval_logits, dim=-1), eval_target, reduction='sum').item()\n",
    "assert total_account > 0, \"eval_data_loader yielded no samples\"\n",
    "# Mean loss over every evaluated sample, not just the final batch.\n",
    "eval_loss = total_loss_sum / total_account\n",
    "logging.warning(f\"eval_loss: {eval_loss}, eval_acc: {total_acc_account / total_account}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 从当前 log 来看，模型训练尚不充分，仍有一定的提升空间"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fake",
   "language": "python",
   "name": "fake"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
